from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


def _expand_conv_param(value, ndim):
    """Return a conv hyper-parameter as an ``ndim``-tuple of ints.

    ``torch.nn.functional.conv2d``/``conv3d`` accept ``stride``, ``padding``
    and ``dilation`` either as a single int or as a list/tuple. A length-1
    sequence is broadcast to every spatial dimension, mirroring PyTorch.
    """
    if isinstance(value, (list, tuple)):
        values = tuple(int(v) for v in value)
        return values * ndim if len(values) == 1 else values
    return (int(value), ) * ndim


def _resolve_string_padding(padding, kernel_size, dilation):
    """Translate PyTorch's ``'valid'``/``'same'`` padding strings to explicit padding.

    ``'same'`` is only expressible as symmetric padding when
    ``dilation * (kernel - 1)`` is even in every spatial dimension; otherwise
    asymmetric padding would be needed and ``NotImplementedError`` is raised.

    Raises:
        NotImplementedError: for ``'same'`` requiring asymmetric padding.
        ValueError: for any other padding string.
    """
    if padding == 'valid':
        return (0, ) * len(kernel_size)
    if padding == 'same':
        totals = [d * (k - 1) for k, d in zip(kernel_size, dilation)]
        if any(t % 2 for t in totals):
            raise NotImplementedError(
                "padding='same' requiring asymmetric padding is not supported")
        return tuple(t // 2 for t in totals)
    raise ValueError("Unsupported padding string %r" % (padding, ))


@tensorrt_converter('torch.nn.functional.conv2d', enabled=trt_version() >= '7.0')
@tensorrt_converter('torch.nn.functional.conv3d', enabled=trt_version() >= '7.0')
def convert_Conv_trt7_functional(ctx):
    """Convert ``torch.nn.functional.conv2d``/``conv3d`` to a TensorRT N-d convolution.

    The call's arguments are read from ``ctx`` (by keyword or by position).
    The weight and the optional bias are copied to host NumPy arrays. A
    ``add_convolution_nd`` layer is added with the requested stride, padding,
    dilation and groups, and its output is attached to the PyTorch result
    tensor as ``_trt``.

    ``stride``/``padding``/``dilation`` may be ints, lists or tuples. ``padding``
    may also be ``'valid'``, or ``'same'`` when it is symmetric.

    NOTE(review): if ``trt_version()`` returns a plain ``str``, ``>= '7.0'`` is
    a lexicographic comparison (``'10.0' < '7.0'``). Confirm before relying on
    this converter with TensorRT 10+.
    """
    input_tensor = get_arg(ctx, 'input', pos=0, default=None)
    weight = get_arg(ctx, 'weight', pos=1, default=None)
    bias = get_arg(ctx, 'bias', pos=2, default=None)
    stride = get_arg(ctx, 'stride', pos=3, default=1)
    padding = get_arg(ctx, 'padding', pos=4, default=0)
    dilation = get_arg(ctx, 'dilation', pos=5, default=1)
    groups = get_arg(ctx, 'groups', pos=6, default=1)

    input_trt = add_missing_trt_tensors(ctx.network, [input_tensor])[0]
    output = ctx.method_return

    # Number of spatial dims: input is (N, C, *spatial).
    spatial_dims = input_tensor.dim() - 2

    # weight.shape[0] is the number of output maps; weight.shape[2:] is the
    # kernel, so the kernel shape is always a fully specified tuple.
    out_channels = int(weight.shape[0])
    kernel_size = tuple(int(k) for k in weight.shape[2:])

    stride = _expand_conv_param(stride, spatial_dims)
    dilation = _expand_conv_param(dilation, spatial_dims)
    if isinstance(padding, str):
        # Needs the expanded dilation to compute 'same' padding.
        padding = _resolve_string_padding(padding, kernel_size, dilation)
    else:
        padding = _expand_conv_param(padding, spatial_dims)

    kernel = weight.detach().cpu().numpy()
    if bias is not None:
        bias = bias.detach().cpu().numpy()

    layer = ctx.network.add_convolution_nd(
        input=input_trt,
        num_output_maps=out_channels,
        kernel_shape=kernel_size,
        kernel=kernel,
        bias=bias)
    layer.stride_nd = stride
    layer.padding_nd = padding
    layer.dilation_nd = dilation

    if groups is not None:
        layer.num_groups = groups

    output._trt = layer.get_output(0)


class FunctionalConv2d(torch.nn.Module):
    """Run a ``torch.nn.Conv2d``'s parameters through ``torch.nn.functional.conv2d``.

    Constructor arguments are forwarded verbatim to ``torch.nn.Conv2d``. The
    wrapped module supplies the weight, bias and hyper-parameters; the forward
    pass calls the functional op positionally instead of ``Conv2d.forward``.
    """

    def __init__(self, *args, **kwargs):
        super(FunctionalConv2d, self).__init__()
        self.conv = torch.nn.Conv2d(*args, **kwargs)

    def forward(self, x):
        layer = self.conv
        # Same positional order as F.conv2d: stride, padding, dilation, groups.
        hyper = (layer.stride, layer.padding, layer.dilation, layer.groups)
        return torch.nn.functional.conv2d(x, layer.weight, layer.bias, *hyper)

class FunctionalConv3d(torch.nn.Module):
    """Run a ``torch.nn.Conv3d``'s parameters through ``torch.nn.functional.conv3d``.

    Constructor arguments are forwarded verbatim to ``torch.nn.Conv3d``. The
    wrapped module supplies the weight, bias and hyper-parameters; the forward
    pass calls the functional op positionally instead of ``Conv3d.forward``.
    """

    def __init__(self, *args, **kwargs):
        super(FunctionalConv3d, self).__init__()
        self.conv = torch.nn.Conv3d(*args, **kwargs)

    def forward(self, x):
        layer = self.conv
        # Same positional order as F.conv3d: stride, padding, dilation, groups.
        hyper = (layer.stride, layer.padding, layer.dilation, layer.groups)
        return torch.nn.functional.conv3d(x, layer.weight, layer.bias, *hyper)

@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_basic_trt7_functional():
    """Functional 1x1 conv2d, stride 1, no padding; 10 -> 5 channels."""
    model = FunctionalConv2d(10, 5, kernel_size=1, stride=1, padding=0)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_stride2_trt7_functional():
    """Functional 1x1 conv2d with stride 2, no padding; 10 -> 5 channels."""
    model = FunctionalConv2d(10, 5, kernel_size=1, stride=2, padding=0)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_kernel3_trt7_functional():
    """Functional 3x3 conv2d, stride 2, padding 1; 10 -> 5 channels."""
    model = FunctionalConv2d(10, 5, kernel_size=3, stride=2, padding=1)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 224, 224)], enabled=trt_version() >= '7.0')
def test_Conv2d_dilation2_trt7_functional():
    """Functional 3x3 conv2d with dilation 2, stride 1, padding 1; 10 -> 5 channels."""
    model = FunctionalConv2d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_basic_trt7_functional():
    """Functional 1x1x1 conv3d, stride 1, no padding; 10 -> 5 channels."""
    model = FunctionalConv3d(10, 5, kernel_size=1, stride=1, padding=0)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_stride2_trt7_functional():
    """Functional 1x1x1 conv3d with stride 2, no padding; 10 -> 5 channels."""
    model = FunctionalConv3d(10, 5, kernel_size=1, stride=2, padding=0)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_kernel3_trt7_functional():
    """Functional 3x3x3 conv3d, stride 2, padding 1; 10 -> 5 channels."""
    model = FunctionalConv3d(10, 5, kernel_size=3, stride=2, padding=1)
    return model


@add_module_test(torch.float32, torch.device('cuda'), [(1, 10, 64, 64, 64)], enabled=trt_version() >= '7.0')
def test_Conv3d_dilation2_trt7_functional():
    """Functional 3x3x3 conv3d with dilation 2, stride 1, padding 1; 10 -> 5 channels."""
    model = FunctionalConv3d(10, 5, kernel_size=3, stride=1, padding=1, dilation=2)
    return model
